
Improve ROCm SDPA fallback behavior and MI350 support #1

Open

keithloweryamd wants to merge 2 commits into andyluo7:master from keithloweryamd:rocm-upstream-candidate

Conversation

@keithloweryamd

Summary

This PR carries a small ROCm-focused cleanup that keeps the training path stable on AMD Instinct while preserving the existing CUDA path.

  • allow overriding the autoresearch cache root with AUTORESEARCH_CACHE_DIR
  • make the ROCm SDPA path try EFFICIENT_ATTENTION first, then fall back to FLASH_ATTENTION and MATH
  • keep logits in bf16 through the softcap and cross-entropy path instead of forcing an intermediate float()
  • add MI350X / MI355X peak-BF16-FLOPS entries so MFU reporting is normalized against the correct hardware
  • force WINDOW_PATTERN to L on ROCm, since the current ROCm SDPA path is full-causal attention and does not implement the sliding-window variants used on the CUDA FA3 path
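The cache-root override in the first bullet boils down to a standard environment-variable fallback. A minimal sketch (AUTORESEARCH_CACHE_DIR is the variable named above; the helper name and the default path are assumptions for illustration):

```python
import os
from pathlib import Path

def resolve_cache_root() -> Path:
    """Resolve the autoresearch cache root.

    Honors AUTORESEARCH_CACHE_DIR when set; otherwise falls back to a
    default under the user's home directory (the default shown here is
    assumed, not taken from the PR).
    """
    override = os.environ.get("AUTORESEARCH_CACHE_DIR")
    if override:
        return Path(override).expanduser()
    return Path.home() / ".cache" / "autoresearch"
```

This keeps the existing behavior for users who set nothing, while letting a shared cluster point the cache at fast local storage.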

The changes are intentionally narrow. They do not include local profiling hooks, workspace-specific wrapper scripts, TunableOp solution files, or other investigation artifacts.

Why

On the target MI350X environment, the previous ROCm path was functional but left performance on the table and could select slower SDPA behavior depending on the runtime/backend combination. The fallback logic makes the desired efficient attention backend explicit while still keeping training runnable if that backend is unavailable.
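The fallback logic described here amounts to trying backends in a fixed preference order and keeping the first usable one. A backend-agnostic sketch (the backend names mirror PyTorch's SDPBackend values; the selector and the availability probe are hypothetical, not the PR's actual code):

```python
# Preference order for the ROCm SDPA path: efficient attention first,
# then flash attention, then the math fallback.
PREFERRED_BACKENDS = ["EFFICIENT_ATTENTION", "FLASH_ATTENTION", "MATH"]

def pick_sdpa_backend(is_available) -> str:
    """Return the first backend in preference order that is usable.

    `is_available` is a caller-supplied probe (hypothetical here); in
    real code this would query the runtime for backend support.
    """
    for backend in PREFERRED_BACKENDS:
        if is_available(backend):
            return backend
    raise RuntimeError("no SDPA backend available")
```

The point of making the order explicit is that training stays runnable even when the preferred efficient-attention backend is unavailable, instead of failing or silently landing on a slow path.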

The WINDOW_PATTERN adjustment also makes the ROCm behavior explicit rather than silently pretending to run the CUDA-side SSSL pattern when the actual backend is full causal attention.
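That adjustment can be expressed as a small guard at config time (a sketch; the function name and config shape are assumptions, but the forced "L" value on ROCm matches the description above):

```python
def effective_window_pattern(requested: str, is_rocm: bool) -> str:
    """On ROCm, force the full-causal "L" pattern, since the current
    ROCm SDPA path does not implement the sliding-window variants used
    on the CUDA FA3 path. Elsewhere, honor the requested pattern.
    """
    if is_rocm and requested != "L":
        # Silently keeping the CUDA-side pattern would misrepresent
        # what the backend actually computes, so override explicitly.
        return "L"
    return requested
```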

MI350X Performance

Using the saved 300-second end-to-end baselines from the investigation:

  • pristine baseline: about 722-731 ms/step, ~717k-726k tok/s, 222.3M tokens in 300s
  • current optimized baseline with these integrated ROCm/runtime changes: about 597-606 ms/step, ~865k-878k tok/s, 267.9M tokens in 300s

That corresponds to:

  • about 1.21x end-to-end speedup
  • about +20.5% more tokens processed in the same 5-minute budget
  • peak VRAM reduced from 105632.2 MB to 97440.2 MB
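The headline figures follow directly from the numbers above; taking the midpoints of the reported step-time ranges, the arithmetic checks out:

```python
# Midpoints of the reported step-time ranges (ms/step).
pristine_ms = (722 + 731) / 2    # 726.5
optimized_ms = (597 + 606) / 2   # 601.5

# End-to-end speedup from step time.
speedup = pristine_ms / optimized_ms
print(f"{speedup:.2f}x")         # 1.21x

# Extra tokens processed in the fixed 300 s budget (millions of tokens).
extra_tokens = 267.9 / 222.3 - 1
print(f"{extra_tokens:+.1%}")    # +20.5%

# Peak VRAM delta (MB).
print(round(105632.2 - 97440.2, 1))  # 8192.0
```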

There was also a transient regression during the ROCm attention work when a less favorable SDPA path was forced; the runtime fallback in this PR is the change that recovered and stabilized the faster path.

Validation

  • python -m py_compile prepare.py train.py
  • MI350X measurements from the saved train.py performance ledger in dev/performance_checkpoints.md

